import torch
import torchvision.models as models
# Conversion settings: the source training checkpoint, the destination ONNX file,
# the device to load/export on, and the number of classes the fine-tuned
# classifier head was trained with (must match the checkpoint's fc layer).
model_dir = "models/model_best.pth.tar"
output = "resnet18.onnx"
device = torch.device("cpu")
num_classes = 2
print('=> running on device ' + str(device))
print('=> loading checkpoint:  ' + model_dir)
# Map stored tensors onto the configured `device` instead of a hard-coded 'cpu',
# so changing `device` above is the only edit needed to load elsewhere.
# NOTE(review): the checkpoint must contain at least 'state_dict' and
# 'resolution' (both are read further down). torch.load unpickles arbitrary
# Python objects — only load checkpoints from a trusted source.
checkpoint = torch.load(model_dir, map_location=device)

# create the model architecture.
# Pretrained ImageNet weights are deliberately NOT requested: every parameter and
# buffer is overwritten by the strict load_state_dict() below, so downloading
# them was wasted work (and required network access). The `pretrained=` keyword
# is also deprecated in recent torchvision releases.
print('=> using model:  resnet18')
model = models.resnet18()

# reshape the model's output: swap the default classifier head for one with
# `num_classes` outputs so its shape matches the fine-tuned checkpoint.
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

# load the model weights (strict by default: keys and shapes must match exactly).
# NOTE(review): if the checkpoint was saved from a DataParallel-wrapped model its
# keys would carry a 'module.' prefix and this load would fail — confirm against
# the training script.
model.load_state_dict(checkpoint['state_dict'])

# Append a Softmax over dim 1 (the class dimension of the (batch, num_classes)
# logits produced by the fc layer) so the exported graph emits probabilities
# instead of raw logits.
print('=> adding nn.Softmax layer to model')
model = torch.nn.Sequential(model, torch.nn.Softmax(1))

model.to(device)
# Inference mode: BatchNorm layers use their stored running statistics rather
# than batch statistics while the graph is traced for export.
model.eval()

# create example image data: a dummy (1, 3, H, W) batch used only to trace the
# graph during export, so its values are irrelevant. It is named `dummy_input`
# (not `input`) to avoid shadowing the built-in input(), and is created on the
# same device as the model.
resolution = int(checkpoint['resolution'])
dummy_input = torch.ones((1, 3, resolution, resolution), device=device)
print('=> input size:  {:d}x{:d}'.format(resolution, resolution))

# export the model.
# No dynamic_axes are passed, so the ONNX graph has a fixed input shape of
# (1, 3, resolution, resolution).
input_names = [ "input_0" ]
output_names = [ "output_0" ]

print('=> exporting model to ONNX...')
torch.onnx.export(model, dummy_input, output, verbose=True, input_names=input_names, output_names=output_names)
print('=> model exported to:  {:s}'.format(output))